{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2ba1b4ef-3b96-4e7e-b5d0-155b839db73c",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/dlai/Functions_Tools_and_Agents_with_LangChain_L1_Function_Calling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f91905c8-21ca-4d81-9614-b9c7344d08c3",
   "metadata": {},
   "source": [
    "This notebook ports the DeepLearning.AI short course [Functions, Tools and Agents with LangChain Lesson 1 OpenAI Function Calling](https://learn.deeplearning.ai/courses/functions-tools-agents-langchain/lesson/2/openai-function-calling) to use Llama 3.\n",
    "\n",
    "You should take the course before or after going through this notebook to have a deeper understanding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31bfe801-e24d-459b-8b3f-e91a34024368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel\n",
    "%pip install groq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88659373-0deb-45eb-8934-0b02d70bd047",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Stand-in weather lookup: the reported conditions are hard coded.\n",
    "# In production, this could be your backend API or an external API\n",
    "def get_current_weather(location, unit=\"fahrenheit\"):\n",
    "    \"\"\"Return the current weather for `location` as a JSON string.\n",
    "\n",
    "    The temperature and forecast are fixed values; `location` and `unit`\n",
    "    are echoed back unchanged from the arguments.\n",
    "    \"\"\"\n",
    "    return json.dumps(\n",
    "        dict(\n",
    "            location=location,\n",
    "            temperature=\"72\",\n",
    "            unit=unit,\n",
    "            forecast=[\"sunny\", \"windy\"],\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Dispatch table: maps the function name returned by the model to the Python callable\n",
    "known_functions = {\n",
    "    \"get_current_weather\": get_current_weather\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359a584a-5b26-4497-afb4-72b63027edb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://console.groq.com/docs/tool-use#models\n",
    "# Groq API endpoints support tool use for programmatic execution of specified operations through requests with explicitly defined \n",
    "# operations. With tool use, Groq API model endpoints deliver structured JSON output that can be used to directly invoke functions.\n",
    "\n",
    "from groq import Groq\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Get a free key at https://console.groq.com/keys and expose it as the GROQ_API_KEY\n",
    "# environment variable (or replace the placeholder below) instead of committing it to the notebook.\n",
    "client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\", \"your_groq_api_key\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc17dc9-2827-4d39-a13d-a4ed5f53c8e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON-schema description of get_current_weather, in the `functions` request format\n",
    "functions = [\n",
    "    {\n",
    "        \"name\": \"get_current_weather\",\n",
    "        \"description\": \"Get the current weather in a given location\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"location\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
    "                },\n",
    "                \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
    "            },\n",
    "            \"required\": [\"location\"],\n",
    "        },\n",
    "    }\n",
    "]\n",
    "\n",
    "# The `tools` request format wraps each function spec as {\"type\": \"function\", \"function\": spec}.\n",
    "# Derive it from `functions` so the two schemas cannot drift apart.\n",
    "tools = [{\"type\": \"function\", \"function\": spec} for spec in functions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a64d28e-b169-4855-b3c2-d6722c56394c",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"What's the weather like in Boston?\"\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a733c1e1-c7f2-4707-b1be-02179df0abc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    #tools=tools, # you can also replace functions with tools, as specified in https://console.groq.com/docs/tool-use \n",
    "    max_tokens=4096, \n",
    "    temperature=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9058073d-cf91-4747-9860-7e2a1d774acf",
   "metadata": {},
   "outputs": [],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd4ed64-0436-499e-a7e5-4224833b72f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_message = response.choices[0].message\n",
    "response_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5458444a-a448-4c5b-b06c-47ab6cd25626",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c669a048-1a3e-43e9-b98f-d0b6a3a0f4c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_message.function_call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f3de5d-5110-486e-8b07-5086939d364d",
   "metadata": {},
   "outputs": [],
   "source": [
    "json.loads(response_message.function_call.arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b69e6497-9e68-47d4-99ae-d45db6c1a8db",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = json.loads(response_message.function_call.arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41a7162-9ce8-4353-827b-f6f3bb278218",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args is the dict decoded from the model's JSON arguments; unpack it into keyword arguments\n",
    "# (passing the dict positionally would make the whole dict the `location` value)\n",
    "get_current_weather(**args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0546f2-de55-417a-9b38-66787b673fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "function_call = response.choices[0].message.function_call\n",
    "function_call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dd1fcf0-7105-4cad-82b5-22ce3b24fc07",
   "metadata": {},
   "outputs": [],
   "source": [
    "function_call.name, function_call.arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d58efe-0ada-48a2-b12b-6bff948a2983",
   "metadata": {},
   "outputs": [],
   "source": [
    "# by defining and using known_functions, we can programmatically call the function the model selected;\n",
    "# function_call.arguments is a JSON string, so decode it and pass the fields as keyword arguments\n",
    "function_response = known_functions[function_call.name](**json.loads(function_call.arguments))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cee6ca19-6924-4a7b-ba7f-7b1a33344ca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "function_response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8480be29-3326-4d95-8742-dff976a7ab2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add the message returned by tool and query LLM again to get final answer\n",
    "messages.append(\n",
    "{\n",
    "    \"role\": \"function\",\n",
    "    \"name\": function_call.name,\n",
    "    \"content\": function_response,\n",
    "}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a42e35f-c601-4c14-8de5-bdbba01dc622",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a2d1ee-9e41-480a-a5cc-62c273d3a179",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    temperature=0\n",
    ")\n",
    "\n",
    "response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54019c56-11cf-465a-a440-296081adee93",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"hi!\",\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "922724ec-1744-4ccf-9a86-5f1823dce0e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call=\"none\", # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c3152a-f51b-45cb-a27c-0672337520b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "582fac7c-0de7-420c-8150-038e74be4b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "response_message = response.choices[0].message\n",
    "response_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d62357-04c9-459c-b36a-89e58444ea63",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"hi!\",\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call=\"auto\", # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632df69d-85bc-4e44-814c-7c1d2fe97228",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"hi!\",\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call=\"none\", # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c773ab17-a620-44eb-877f-9e0bc23fb00b",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"What's the weather in Boston?\",\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call=\"none\", # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a8ee80-83ae-4189-837c-54bb9c93c315",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"hi!\",\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call={\"name\": \"get_current_weather\"}, # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa5801a-2e71-4630-a8cd-7e84d1214f51",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"What's the weather like in Boston!\",\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    "    functions=functions,\n",
    "    function_call={\"name\": \"get_current_weather\"}, # default is \"auto\" (let the LLM decide whether to call a function); can also be \"none\" or a dict {\"name\": \"func_name\"}\n",
    "    temperature=0\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de5924d4-4225-48d1-a390-e44f3167d547",
   "metadata": {},
   "outputs": [],
   "source": [
    "function_call = response.choices[0].message.function_call\n",
    "function_call.name, function_call.arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9f3340-b905-47f3-a478-cf3d786faa1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the JSON-encoded arguments, then unpack them as keyword arguments for the selected function\n",
    "# (passing the dict positionally would make the whole dict the `location` value)\n",
    "args = json.loads(response.choices[0].message.function_call.arguments)\n",
    "observation = known_functions[function_call.name](**args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c31e9b5-99ed-46f3-8849-133c71ea87d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "observation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b73c550e-5aa2-49de-8422-0c3e706f1df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages.append(\n",
    "        {\n",
    "            \"role\": \"function\",\n",
    "            \"name\": function_call.name,\n",
    "            \"content\": observation,\n",
    "        }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60302f1-07e2-4f22-bd60-b54e1ea2e3db",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35f7f3d-4e39-4744-b5e3-2065e67eea28",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"llama3-70b-8192\",\n",
    "    messages=messages,\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d4745e1-0477-4b6b-84de-9c82e0bc2452",
   "metadata": {},
   "outputs": [],
   "source": [
    "response.choices[0].message.content"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
